package drds.data_propagate.driver;

import drds.data_propagate.driver.packets.HeaderPacket;
import drds.data_propagate.driver.packets.client.command_packet.QueryPacket;
import drds.data_propagate.driver.packets.server.*;
import drds.data_propagate.driver.socket.SocketChannel;
import drds.data_propagate.driver.utils.PacketManager;

import java.io.IOException;
import java.util.ArrayList;
import java.util.List;

/**
 * Sends query commands over a {@link SocketChannel} and reads back the packets of the
 * resulting result set(s).
 * <p>
 * Output data is UTF-8 encoded by default; transcode it properly if another encoding is needed.
 */
public class QueryExecutor {

    /** Leading byte of an error packet (0xFF). */
    private static final byte ERROR_MARKER = (byte) 0xFF;

    /** Leading byte of an EOF packet (0xFE). */
    private static final byte EOF_MARKER = (byte) 0xFE;

    /** Leading byte of a LOCAL INFILE request (0xFB); this executor cannot answer it. */
    private static final byte LOCAL_INFILE_MARKER = (byte) 0xFB;

    /** Status-flag bit of an EOF packet that tells {@link #querys(String)} to read another result set. */
    private static final int MORE_RESULTS_FLAG = 0x0008;

    private final SocketChannel socketChannel;

    /**
     * Creates an executor on top of an already connected {@link Connector}.
     *
     * @param connector the connector whose channel is used for all queries
     * @throws IOException if the connector is not connected yet
     */
    public QueryExecutor(Connector connector) throws IOException {
        if (!connector.isConnected()) {
            throw new IOException("should execute connector.connect() first");
        }
        this.socketChannel = connector.getSocketChannel();
    }

    /**
     * Creates an executor that talks directly over the given channel.
     *
     * @param socketChannel the channel used for all queries
     */
    public QueryExecutor(SocketChannel socketChannel) {
        this.socketChannel = socketChannel;
    }

    /**
     * Sends a query and reads one result set. The packets are consumed in this order:
     * <br>
     * (Result Set Header Packet) the number of columns <br>
     * (Field Packets) column descriptors <br>
     * (EOF Packet) marker: end of Field Packets <br>
     * (Row Data Packets) row contents <br>
     * (EOF Packet) marker: end of Data Packets
     * <p>
     * NOTE(review): a response that carries no result set (for example an OK packet) is not
     * handled specially; column and EOF packets are still expected to follow. Confirm that
     * callers only send statements that produce a result set.
     *
     * @param queryString the statement to send
     * @return the column descriptors and row values that were read
     * @throws IOException if the server answers with an error packet (before or while rows
     *                     are streamed), requests a LOCAL INFILE transfer, sends an unexpected
     *                     packet, or the channel fails
     */
    public ResultSetPacket query(String queryString) throws IOException {
        sendQuery(queryString);
        byte[] headerBytes = readNextPacket();
        rejectNonResultSetResponse(headerBytes, queryString);
        List<ColumnPacket> columnPacketList = readColumnPackets(headerBytes);
        // The "more results" flag is irrelevant when only one result set is read.
        readEofPacket();
        List<RowDataPacket> rowDataPacketList = readRowDataPackets(queryString);
        return buildResultSetPacket(columnPacketList, rowDataPacketList);
    }

    /**
     * Sends a query and reads result sets for as long as the EOF packet that terminates the
     * column descriptors has the {@link #MORE_RESULTS_FLAG} bit set.
     *
     * @param queryString the statement(s) to send
     * @return the result sets in the order they were received
     * @throws IOException if the server answers with an error packet (before or while rows
     *                     are streamed), requests a LOCAL INFILE transfer, sends an unexpected
     *                     packet, or the channel fails
     */
    public List<ResultSetPacket> querys(String queryString) throws IOException {
        sendQuery(queryString);
        List<ResultSetPacket> resultSetPacketList = new ArrayList<ResultSetPacket>();
        boolean moreResult = true;
        while (moreResult) {
            byte[] headerBytes = readNextPacket();
            rejectNonResultSetResponse(headerBytes, queryString);
            List<ColumnPacket> columnPacketList = readColumnPackets(headerBytes);
            moreResult = readEofPacket();
            List<RowDataPacket> rowDataPacketList = readRowDataPackets(queryString);
            resultSetPacketList.add(buildResultSetPacket(columnPacketList, rowDataPacketList));
        }
        return resultSetPacketList;
    }

    /** Serializes the query command and writes it to the channel. */
    private void sendQuery(String queryString) throws IOException {
        QueryPacket queryPacket = new QueryPacket();
        queryPacket.setQueryString(queryString);
        PacketManager.writeBodyBytes(socketChannel, queryPacket.toBytes());
    }

    /**
     * Fails when the first response packet is an error packet or a LOCAL INFILE request.
     * <p>
     * Only 0xFF marks an error. Testing for "any negative byte" would also reject valid
     * column counts of 128 and above, because those bytes are negative as signed Java bytes.
     */
    private static void rejectNonResultSetResponse(byte[] packetBodyBytes, String queryString) throws IOException {
        byte marker = leadingByte(packetBodyBytes);
        if (marker == ERROR_MARKER) {
            throw toQueryFailure(packetBodyBytes, queryString);
        }
        if (marker == LOCAL_INFILE_MARKER) {
            // Parsing this as a header would leave both sides waiting for each other.
            throw new IOException("LOCAL INFILE request is not supported\n with command: " + queryString);
        }
    }

    /** Parses the result set header and reads as many column packets as it announces. */
    private List<ColumnPacket> readColumnPackets(byte[] headerBytes) throws IOException {
        ResultSetHeaderPacket resultSetHeaderPacket = new ResultSetHeaderPacket();
        resultSetHeaderPacket.fromBytes(headerBytes);
        List<ColumnPacket> columnPacketList = new ArrayList<ColumnPacket>();
        for (int i = 0; i < resultSetHeaderPacket.getColumnCount(); i++) {
            ColumnPacket columnPacket = new ColumnPacket();
            columnPacket.fromBytes(readNextPacket());
            columnPacketList.add(columnPacket);
        }
        return columnPacketList;
    }

    /** Reads row packets until the terminating EOF packet; an error packet aborts the read. */
    private List<RowDataPacket> readRowDataPackets(String queryString) throws IOException {
        List<RowDataPacket> rowDataPacketList = new ArrayList<RowDataPacket>();
        while (true) {
            byte[] packetBodyBytes = readNextPacket();
            byte marker = leadingByte(packetBodyBytes);
            if (marker == EOF_MARKER) {
                return rowDataPacketList;
            }
            if (marker == ERROR_MARKER) {
                // The server may abort while streaming rows; never treat that as row data.
                throw toQueryFailure(packetBodyBytes, queryString);
            }
            RowDataPacket rowDataPacket = new RowDataPacket();
            rowDataPacket.fromBytes(packetBodyBytes);
            rowDataPacketList.add(rowDataPacket);
        }
    }

    /** Assembles the result object from the collected column and row packets. */
    private ResultSetPacket buildResultSetPacket(List<ColumnPacket> columnPacketList,
                                                 List<RowDataPacket> rowDataPacketList) throws IOException {
        ResultSetPacket resultSetPacket = new ResultSetPacket();
        resultSetPacket.getColumnPacketList().addAll(columnPacketList);
        for (RowDataPacket rowDataPacket : rowDataPacketList) {
            resultSetPacket.getColumnNameList().addAll(rowDataPacket.getColumnNameList());
        }
        resultSetPacket.setSourceSocketAddress(socketChannel.getRemoteSocketAddress());
        return resultSetPacket;
    }

    /** Builds the exception reported for an error packet, including the failed command. */
    private static IOException toQueryFailure(byte[] packetBodyBytes, String queryString) throws IOException {
        ErrorPacket errorPacket = new ErrorPacket();
        errorPacket.fromBytes(packetBodyBytes);
        return new IOException(errorPacket + "\n with command: " + queryString);
    }

    /** Returns the first byte of a packet body, failing cleanly on an empty body. */
    private static byte leadingByte(byte[] packetBodyBytes) throws IOException {
        if (packetBodyBytes == null || packetBodyBytes.length == 0) {
            throw new IOException("unexpected empty packet");
        }
        return packetBodyBytes[0];
    }

    /**
     * Reads the next packet and requires it to be an EOF packet.
     *
     * @return {@code true} if the {@link #MORE_RESULTS_FLAG} bit is set in the packet's status flag
     * @throws IOException if the packet is empty or is not an EOF packet
     */
    private boolean readEofPacket() throws IOException {
        byte[] bytes = readNextPacket();
        byte marker = leadingByte(bytes);
        // Validate before parsing so a non-EOF packet is reported as such.
        if (marker != EOF_MARKER) {
            throw new IOException("EOF Packet is expected, but packet with field_count=" + marker + " is found.");
        }
        EofPacket eofPacket = new EofPacket();
        eofPacket.fromBytes(bytes);
        return (eofPacket.statusFlag & MORE_RESULTS_FLAG) != 0;
    }

    /**
     * Reads one packet: a 4-byte header followed by the body whose length the header announces.
     *
     * @return the packet body
     * @throws IOException if reading from the channel fails
     */
    protected byte[] readNextPacket() throws IOException {
        HeaderPacket headerPacket = PacketManager.readHeader(socketChannel, 4);
        return PacketManager.readBytes(socketChannel, headerPacket.getPacketBodyLength());
    }
}
